Summary¶

This notebook compares the predictions of a baseline LSTM model against the same model after hyperparameter tuning with Optuna. CRPS was chosen as the optimization metric.

Baseline hyperparameters¶

hidden_size: 64
learning_rate: 0.001
batch_size: 16
dropout: 0.2

CRPS = 6.44

Tuned hyperparameters¶

hidden_size: 78
learning_rate: 0.00365
batch_size: 8
dropout: 0.477

CRPS = 6.02

Comments¶

Tuning significantly improved the CRPS score, primarily by narrowing the prediction spread across the dataset. It also slightly reduced the expected value residuals at intermediate DAM price divergences, giving a better fit.

However, tuning introduced some additional coverage loss and bias (particularly at the low end of the price predictions); under CRPS, the spread reduction more than offset this. Whether the CRPS improvement translates to better real-world performance depends on the relative importance of coverage versus prediction width in the battery dispatch optimization algorithm.

The hyperparameter tuning significantly increased the dropout rate and decreased the batch size, implying that overfitting is likely a problem: better results come from forcing the model to learn large-scale patterns instead of point-to-point noise. Even so, the tuned model introduced some oscillatory artifacts, visible in the March 8-9 prediction.

Both the tuned model and the baseline still have trouble predicting strong outlier price spikes. These events are too rare and extreme to give the model sufficient incentive to account for them.

Improvements¶

  • This model likely needs more data: both from different nodes, to provide more examples of daily price signals, and more diverse features such as generation mix, weather forecasts, and congestion information.

  • Move from a single LSTM model to an encoder-decoder architecture. The current model predicts only one point at a time; a system that maps an entire day of input to an entire day of output would likely capture daily patterns much better.

  • Train a separate model specifically to predict price spikes and correct extreme outlier bias. This model's expected value residuals could be used as its training target.
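The encoder-decoder idea in the second bullet can be sketched as follows, assuming PyTorch and entirely hypothetical layer sizes; this is not the project's actual model code, just an illustration of a full-day-in, full-day-out architecture.

```python
import torch
import torch.nn as nn

class Seq2SeqLSTM(nn.Module):
    """Sketch of an encoder-decoder LSTM mapping a full day of inputs
    to a full day of quantile predictions, instead of predicting one
    point at a time. All sizes here are assumptions."""
    def __init__(self, n_features, hidden_size, n_quantiles, horizon):
        super().__init__()
        self.horizon = horizon
        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, n_quantiles)

    def forward(self, x):                       # x: (batch, day_len, n_features)
        _, (h, c) = self.encoder(x)             # summarize the input day
        # feed the encoder summary at every decoder step
        dec_in = h[-1:].transpose(0, 1).repeat(1, self.horizon, 1)
        out, _ = self.decoder(dec_in, (h, c))
        return self.head(out)                   # (batch, horizon, n_quantiles)

model = Seq2SeqLSTM(n_features=8, hidden_size=64, n_quantiles=9, horizon=24)
y = model(torch.randn(2, 24, 8))
print(y.shape)  # torch.Size([2, 24, 9])
```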

In [1]:
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.colors as pc
import plotly.graph_objects as go
import yaml

from price_forecasting.config import MODELS_DIR, PROCESSED_DATA_DIR
from price_forecasting.utils.scoring_tools import get_mean_crps

Loading Data¶

In [2]:
MODEL_DIR = MODELS_DIR / 'LSTM_tuning'
BASELINE_MODEL_DIR = MODELS_DIR / 'LSTM_v1'

model_label = "Tuned"
In [3]:
# load config file and variables from model run
with open(MODEL_DIR / 'config.yaml', 'r') as f:
    config = yaml.safe_load(f)
DATA_SOURCE = PROCESSED_DATA_DIR / config['data_source']
quantiles = np.array(config['quantiles'])

# load X data in DataFrame format
X_test = pd.read_parquet(DATA_SOURCE / 'X_test.pqt')

# load y_test data and format as numpy array
y_test = pd.read_parquet(DATA_SOURCE / 'y_test.pqt').to_numpy()
y_test = y_test.reshape([-1])

# load y predictions from model
y_pred = np.load(MODEL_DIR / 'y_pred.npy')

# baseline predictions loading
y_pred_baseline = np.load(BASELINE_MODEL_DIR / 'y_pred.npy')

dam = X_test['DAM_PRC'] #day ahead price data

time = X_test.index
time = time.tz_convert('US/Pacific')

Plotting¶

In [4]:
plotly.offline.init_notebook_mode()

def quantile_plot():
    fig = go.Figure()

    n = len(quantiles)
    colors = pc.sample_colorscale('Viridis', [i/n for i in range(n)])

    for i, yi in enumerate(y_pred.T):
        fill = None if i == 0 else 'tonexty'
        fig.add_trace(go.Scatter(x=time, y=yi + dam, mode='lines', name=quantiles[i], 
                                 line=dict(width=1,color=colors[i]), fill=fill))

    fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                              line=dict(width=2, color='black', dash='dot')))

    start = datetime.fromisoformat('2025-03-01 00:00:00')
    end = datetime.fromisoformat('2025-03-08 00:00:00')

    fig.update_layout(title='SP15 RTM Price Prediction',
                      xaxis_title='Time',
                      width=1100, 
                      height=500,
                      yaxis_title='Price ($/MWh)',
                      xaxis=dict(range=[start, end]),   
                      yaxis=dict(range=[-100, 250]), 
                     )

    fig.show()

quantile_plot()

Model Quantification¶

CRPS¶

In [5]:
crps = get_mean_crps(y_pred, y_test, quantiles)
print(f"CRPS: {crps:.3g}")
CRPS: 6.04
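The `get_mean_crps` implementation lives in the project's `scoring_tools` module and is not shown here. A common way to approximate CRPS from quantile forecasts, which it may or may not follow, is twice the pinball (quantile) loss averaged over all quantile levels:

```python
import numpy as np

def pinball_crps(y_pred, y_true, quantiles):
    """Approximate CRPS from quantile forecasts as twice the average
    pinball loss across all quantile levels.

    y_pred: (n_samples, n_quantiles) predicted quantile values
    y_true: (n_samples,) observed values
    quantiles: (n_quantiles,) quantile levels in (0, 1)
    """
    diff = y_true[:, None] - y_pred                         # (n, q)
    loss = np.maximum(quantiles * diff, (quantiles - 1) * diff)
    return 2 * loss.mean()

# toy check: a perfectly sharp forecast at the true value scores 0
q = np.array([0.1, 0.5, 0.9])
y = np.array([10.0, 20.0])
perfect = np.tile(y[:, None], (1, 3))
print(pinball_crps(perfect, y, q))  # → 0.0
```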

Coverage¶

In [6]:
def interval_violation_plot():
    fig = go.Figure()

    q_hi = 0.99
    q_low = 0.01

    i_low = np.where(quantiles == q_low)[0][0]
    y_pred_low = y_pred[:,i_low]
    under = (y_test < y_pred_low).astype(float) * -200

    i_hi = np.where(quantiles == q_hi)[0][0]
    y_pred_hi = y_pred[:,i_hi]
    over = (y_test > y_pred_hi).astype(float) * 200

    fig.add_trace(go.Scatter(x=time, y=y_pred_low + dam, mode='lines', name=q_low, 
                            line=dict(width=1,color="gray"), fill=None))


    fig.add_trace(go.Scatter(x=time, y=y_pred_hi + dam, mode='lines', name=q_hi, 
                            line=dict(width=1,color="gray"), fill="tonexty"))

    fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                          line=dict(width=2, color='black', dash='dot')))

    fig.add_trace(go.Scatter(x=time, y=under, mode='lines', name='Under', 
                            line=dict(width=0,color="blue"), fill="tozeroy"))

    fig.add_trace(go.Scatter(x=time, y=over, mode='lines', name='Over', 
                            line=dict(width=0,color="red"), fill="tozeroy"))


    start = datetime.fromisoformat('2025-03-01 00:00:00')
    end = datetime.fromisoformat('2025-03-08 00:00:00')

    fig.update_layout(title='1-99% Interval Violation Plot',
                      xaxis_title='Time',
                      width=1000, 
                      height=500,
                      yaxis_title='Price ($/MWh)',
                      xaxis=dict(range=[start, end]),   
                      yaxis=dict(range=[-100, 250]), 
                     )

    fig.show()
interval_violation_plot()
In [7]:
fig = go.Figure()

def coverage_calibration_plot(y_pred, label, fig):
    coverage = []
    for i, q in enumerate(quantiles):
        below = (y_test <= y_pred[:,i]).mean()
        coverage.append(below)

    fig.add_trace(go.Scatter(x=quantiles, y=quantiles - coverage, mode="markers", name=label,
                              marker=dict(size=8)))


coverage_calibration_plot(y_pred, model_label, fig)
coverage_calibration_plot(y_pred_baseline, "Baseline", fig)

fig.add_trace(go.Scatter(x=[0,1], y=[0,0], mode='lines', name='Expected',
                          line=dict(width=2, color='black', dash='dot')))

fig.update_layout(title='Coverage Calibration',
                 xaxis_title='Quantile',
                 width=650, 
                 height=500,
                 yaxis_title='Coverage Bias',
                 yaxis=dict(range=[-0.05, 0.05]), 
                 xaxis=dict(range=[0.0, 1.0]), 
                )
fig.show()

Expected Value Residuals¶

In [8]:
fig = go.Figure()

def residual_plot(y_pred, label, fig):
    evs = [] #expected value from quantiles
    for pred in y_pred:
        ev = np.trapezoid(pred, quantiles)
        evs.append(ev)
    evs = np.array(evs)

    residual = y_test - evs

    fig.add_trace(go.Scatter(x=y_test, y=residual, mode="markers", name=label,
                              marker=dict(size=3,opacity=0.5)))

residual_plot(y_pred, model_label, fig)
residual_plot(y_pred_baseline, "Baseline", fig)

fig.add_trace(go.Scatter(x=[-100,1000], y=[0,0], mode='lines', name='Expected',
                          line=dict(width=2, color='black', dash='dot')))

fig.update_layout(title='Expected Value Residuals',
              xaxis_title='True Value',
              width=1000, 
              height=500,
              yaxis_title='Residual',
              xaxis=dict(range=[-100, 250]),   
              yaxis=dict(range=[-100, 250]), 
             )
fig.show()

Probability Integral Transform¶

In [9]:
def PIT_plot(y_pred, label, alpha=1):
    PIT = []
    for val, pred in zip(y_test, y_pred):
        p = np.interp(val, pred, quantiles, left=0.0, right=1.0)
        PIT.append(p)

    plt.hist(PIT, bins=4, density=True, label=label, alpha=alpha)

PIT_plot(y_pred, model_label)
PIT_plot(y_pred_baseline, "Baseline", alpha=0.5)

plt.axhline(1, color='k', linestyle='--')
plt.xlabel('Quartile Bin')
plt.ylabel('Probability Density')
plt.title("Probability Integral Transform")
plt.legend()
plt.show()

Quantile Spread¶

In [10]:
# Measure prediction spread
fig = go.Figure()

def spread_plot(y_pred, label, fig, alpha=1):

    q_hi = 0.95
    q_low = 0.05

    i_low = np.where(quantiles == q_low)[0][0]
    y_pred_low = y_pred[:,i_low]

    i_hi = np.where(quantiles == q_hi)[0][0]
    y_pred_hi = y_pred[:,i_hi]

    spread = y_pred_hi - y_pred_low
    avg_width = spread.mean()
    print(f'Average 5-95% Prediction Width: {avg_width:.3g} $/MWh')

    fig.add_trace(go.Scatter(x=time, y=spread, mode='lines', name=label+" Spread", 
                            line=dict(width=2), fill=None))


    plt.hist(spread, density=True, label=label, alpha=alpha)

spread_plot(y_pred, model_label, fig)
spread_plot(y_pred_baseline, "Baseline", fig, alpha=0.5)


fig.add_trace(go.Scatter(x=time, y=y_test + dam, mode='lines', name='RTM Price',
                      line=dict(width=1, color='black', dash='dot')))


start = datetime.fromisoformat('2025-03-01 00:00:00')
end = datetime.fromisoformat('2025-03-08 00:00:00')

fig.update_layout(title='5-95% Spread Measure',
                  xaxis_title='Time',
                  width=1000, 
                  height=500,
                  yaxis_title='Price ($/MWh)',
                  xaxis=dict(range=[start, end]),   
                  yaxis=dict(range=[-100, 250]), 
                 )
fig.show()

plt.xlabel('Spread ($/MWh)')
plt.ylabel('Probability Density')
plt.title("Spread Distribution")
plt.legend()
plt.show()
Average 5-95% Prediction Width: 30.6 $/MWh
Average 5-95% Prediction Width: 33.4 $/MWh